'''
Description: 
version: 
Author: aps.auto
Date: 2024-06-26 11:39:40
LastEditors: Luoxl
LastEditTime: 2024-07-02 08:41:26
'''
from ultralytics import YOLO
import ultralytics
import torch
import numpy as np
import os
import cv2

def extract_masks(result, target_class):
    """Return boolean masks for every detection of ``target_class`` in ``result``.

    Args:
        result: One segmentation result (e.g. an element of the list returned
            by ``model.predict``). Only ``result.boxes.cls``,
            ``result.masks`` and ``result.names`` are read.
        target_class: Class name to keep; compared against
            ``result.names[class_id]``.

    Returns:
        A list of numpy boolean arrays, one per matching detection, in
        detection order. Each array is ``True`` where the mask value is > 0.
        The list is empty when nothing matches or there are no detections.
    """
    # Tensors may live on the GPU; Tensor.numpy() only works on CPU tensors,
    # so move to host memory first (a no-op for tensors already on the CPU).
    class_ids = result.boxes.cls.data.cpu().numpy()

    masks = []
    for i, class_id in enumerate(class_ids):
        # cls stores class ids as floats; cast so the lookup works for both
        # int-keyed dicts and list-style ``names``.
        if result.names[int(class_id)] == target_class:
            mask = result.masks[i].data.cpu().numpy()
            masks.append(np.squeeze(mask) > 0)

    return masks

import time
import glob
# Load the custom segmentation weights. To benchmark the stock model instead,
# point MODEL_PATH at the official "yolov8n-seg.pt" checkpoint.
MODEL_PATH = "D:/BaiduSyncdisk/Projects/yolov8-segment/yolov8n-seg-AGI.pt"
model = ultralytics.YOLO(MODEL_PATH)

# Warm-up inference so one-time initialization cost is excluded from the
# timing below. Ultralytics documents tensor sources as BCHW float in [0, 1],
# so sample with torch.rand (torch.randn produces values outside that range).
warmup_input = torch.rand((1, 3, 480, 640))
model.predict(warmup_input)

# result = model.export(format='onnx')  # native YOLOv8 ONNX export
# perf_counter is monotonic and high-resolution, unlike time.time().
t1 = time.perf_counter()
model.predict('test.png')
print(f"总耗时{(time.perf_counter() - t1) * 1000}ms")


# source = "D:/BaiduSyncdisk/Projects/yolov8-segment/robot_seg_dataset/images/test"
# png_files = glob.glob(os.path.join(source, '*.png'))
# print(f"共计 {len(png_files)}张图像")
# t1 = time.time()
# results = model.predict(source)  # predict on an image

# print('ok')
# masks = extract_masks(results[0], "cell phone")
# mask = masks[0].astype(np.uint8)  
# output = np.repeat(mask[:, :, np.newaxis], 3, axis=2)*255
# cv2.imwrite('output_mask.png', output)
# t2 = time.time()
# print(f"总耗时{(t2-t1)*1000}ms")
# print(f"推理一张图像耗时{1000*(t2-t1)/len(png_files)}ms")